#include "bspline_opt/uniform_bspline.h"
#include "nav_msgs/Odometry.h"
#include "quadrotor_msgs/Command.h"
#include "std_msgs/Empty.h"
#include "traj_utils/Bspline.h"
#include "visualization_msgs/Marker.h"
#include <ros/ros.h>
#include <std_msgs/Int32.h>

// Publisher used by cmdCallback to send commands (advertised in main).
ros::Publisher pos_cmd_pub;

// Command message that cmdCallback fills in and publishes.
quadrotor_msgs::Command cmd;
// Per-axis gains, all zero. Not referenced elsewhere in the visible part of
// this file.
double pos_gain[3] = {0, 0, 0};
double vel_gain[3] = {0, 0, 0};

using ego_planner::UniformBspline;

// Becomes true once bsplineCallback has stored a trajectory.
bool receive_traj_ = false;
// Active trajectory as filled by bsplineCallback: [0] is the position spline,
// [1] is getDerivative() of [0], [2] is getDerivative() of [1].
vector<UniformBspline> traj_;
// Result of traj_[0].getTimeSum() for the active trajectory.
double traj_duration_;
// start_time and traj_id fields copied from the latest Bspline message.
ros::Time start_time_;
int traj_id_;

// yaw control
// Yaw and yaw rate produced by the most recent calculate_yaw call.
double last_yaw_, last_yaw_dot_;
// Offset added to the current trajectory time to pick the point the yaw
// should face; read from the "traj_server/time_forward" parameter in main.
double time_forward_;

// Last value stored by stateCallback.
int state = 0;

/**
 * Subscriber callback for incoming B-spline trajectories.
 *
 * Converts the message's control points and knots into a UniformBspline,
 * derives it twice, and installs the three splines as the active trajectory
 * (traj_[0..2]) together with its start time, id and duration.
 *
 * @param msg  Received B-spline message.
 */
void bsplineCallback(traj_utils::BsplineConstPtr msg) {
    const size_t num_ctrl_pts = msg->pos_pts.size();
    const size_t num_knots = msg->knots.size();

    // Control points are stored column-wise: one 3-vector per point.
    Eigen::MatrixXd ctrl_pts(3, num_ctrl_pts);
    for (size_t col = 0; col < num_ctrl_pts; ++col) {
        const auto &pt = msg->pos_pts[col];
        ctrl_pts(0, col) = pt.x;
        ctrl_pts(1, col) = pt.y;
        ctrl_pts(2, col) = pt.z;
    }

    Eigen::VectorXd knot_vec(num_knots);
    for (size_t k = 0; k < num_knots; ++k) {
        knot_vec(k) = msg->knots[k];
    }

    // The knot vector carried by the message is installed right after
    // construction.
    UniformBspline pos_spline(ctrl_pts, msg->order, 0.1);
    pos_spline.setKnot(knot_vec);

    start_time_ = msg->start_time;
    traj_id_ = msg->traj_id;

    // Rebuild the trajectory stack: position, then each successive
    // derivative of the spline pushed just before it.
    traj_.clear();
    traj_.push_back(pos_spline);
    traj_.push_back(traj_.back().getDerivative());
    traj_.push_back(traj_.back().getDerivative());

    traj_duration_ = traj_.front().getTimeSum();

    receive_traj_ = true;
}

std::pair<double, double> calculate_yaw(double t_cur, Eigen::Vector3d &pos,
                                        ros::Time &time_now,
                                        ros::Time &time_last) {
    constexpr double PI = 3.1415926;
    constexpr double YAW_DOT_MAX_PER_SEC = PI;
    // constexpr double YAW_DOT_DOT_MAX_PER_SEC = PI;
    std::pair<double, double> yaw_yawdot(0, 0);
    double yaw = 0;
    double yawdot = 0;

    Eigen::Vector3d dir =
        t_cur + time_forward_ <= traj_duration_
            ? traj_[0].evaluateDeBoorT(t_cur + time_forward_) - pos
            : traj_[0].evaluateDeBoorT(traj_duration_) - pos;
    double yaw_temp = dir.norm() > 0.1 ? atan2(dir(1), dir(0)) : last_yaw_;
    double max_yaw_change =
        YAW_DOT_MAX_PER_SEC * (time_now - time_last).toSec();
    if (yaw_temp - last_yaw_ > PI) {
        if (yaw_temp - last_yaw_ - 2 * PI < -max_yaw_change) {
            yaw = last_yaw_ - max_yaw_change;
            if (yaw < -PI)
                yaw += 2 * PI;

            yawdot = -YAW_DOT_MAX_PER_SEC;
        } else {
            yaw = yaw_temp;
            if (yaw - last_yaw_ > PI)
                yawdot = -YAW_DOT_MAX_PER_SEC;
            else
                yawdot =
                    (yaw_temp - last_yaw_) / (time_now - time_last).toSec();
        }
    } else if (yaw_temp - last_yaw_ < -PI) {
        if (yaw_temp - last_yaw_ + 2 * PI > max_yaw_change) {
            yaw = last_yaw_ + max_yaw_change;
            if (yaw > PI)
                yaw -= 2 * PI;

            yawdot = YAW_DOT_MAX_PER_SEC;
        } else {
            yaw = yaw_temp;
            if (yaw - last_yaw_ < -PI)
                yawdot = YAW_DOT_MAX_PER_SEC;
            else
                yawdot =
                    (yaw_temp - last_yaw_) / (time_now - time_last).toSec();
        }
    } else {
        if (yaw_temp - last_yaw_ < -max_yaw_change) {
            yaw = last_yaw_ - max_yaw_change;
            if (yaw < -PI)
                yaw += 2 * PI;

            yawdot = -YAW_DOT_MAX_PER_SEC;
        } else if (yaw_temp - last_yaw_ > max_yaw_change) {
            yaw = last_yaw_ + max_yaw_change;
            if (yaw > PI)
                yaw -= 2 * PI;

            yawdot = YAW_DOT_MAX_PER_SEC;
        } else {
            yaw = yaw_temp;
            if (yaw - last_yaw_ > PI)
                yawdot = -YAW_DOT_MAX_PER_SEC;
            else if (yaw - last_yaw_ < -PI)
                yawdot = YAW_DOT_MAX_PER_SEC;
            else
                yawdot =
                    (yaw_temp - last_yaw_) / (time_now - time_last).toSec();
        }
    }

    if (fabs(yaw - last_yaw_) <= max_yaw_change)
        yaw = 0.5 * last_yaw_ + 0.5 * yaw; // nieve LPF
    yawdot = 0.5 * last_yaw_dot_ + 0.5 * yawdot;
    last_yaw_ = yaw;
    last_yaw_dot_ = yawdot;

    yaw_yawdot.first = yaw;
    yaw_yawdot.second = yawdot;

    return yaw_yawdot;
}

/**
 * Store the integer payload of the received message in the global `state`.
 *
 * @param msg  Received Int32 message.
 */
void stateCallback(const std_msgs::Int32::ConstPtr &msg) {
    const std_msgs::Int32 &received = *msg;
    state = received.data;
}

/**
 * Timer callback: sample the active trajectory at the current time and
 * publish the resulting position/velocity/acceleration/yaw command.
 *
 * Nothing is published
 *   - before a trajectory has been received,
 *   - once the trajectory time has reached traj_duration_, and
 *   - while the trajectory time is invalid (negative, i.e. the start time is
 *     still in the future, or not a number).
 *
 * @param e  Timer event (unused).
 */
void cmdCallback(const ros::TimerEvent &e) {
    /* no publishing before receive traj_ */
    if (!receive_traj_)
        return;

    ros::Time time_now = ros::Time::now();
    const double t_cur = (time_now - start_time_).toSec();

    // Time stamp of the previous tick, used by calculate_yaw for its rate
    // limit. Initialised on the first call that gets this far.
    static ros::Time time_last = ros::Time::now();

    const bool on_trajectory = t_cur >= 0.0 && t_cur < traj_duration_;
    if (!on_trajectory) {
        if (!(t_cur >= traj_duration_)) {
            // Publishing here would send an all-zero setpoint (world origin,
            // yaw 0), so only report the problem.
            cout << "[Traj server]: invalid time." << endl;
        }
        // NOTE(review): after the trajectory ends the original code prepared
        // a hover setpoint but returned before publishing it; that
        // non-publishing behaviour is kept -- confirm it is intended.

        // Keep the tick time current so the yaw rate limit is still based on
        // one timer period when sampling resumes.
        time_last = time_now;
        return;
    }

    Eigen::Vector3d pos(Eigen::Vector3d::Zero()), vel(Eigen::Vector3d::Zero()),
        acc(Eigen::Vector3d::Zero());
    pos = traj_[0].evaluateDeBoorT(t_cur);
    vel = traj_[1].evaluateDeBoorT(t_cur);
    acc = traj_[2].evaluateDeBoorT(t_cur);

    /*** calculate yaw ***/
    const std::pair<double, double> yaw_yawdot =
        calculate_yaw(t_cur, pos, time_now, time_last);
    time_last = time_now;

    cmd.header.stamp = time_now;
    cmd.header.frame_id = "world";
    // cmd.trajectory_flag = quadrotor_msgs::Command::TRAJECTORY_STATUS_READY;
    // cmd.trajectory_id = traj_id_;

    cmd.position.x = pos(0);
    cmd.position.y = pos(1);
    cmd.position.z = pos(2);

    cmd.velocity.x = vel(0);
    cmd.velocity.y = vel(1);
    cmd.velocity.z = vel(2);

    cmd.acceleration.x = acc(0);
    cmd.acceleration.y = acc(1);
    cmd.acceleration.z = acc(2);

    cmd.yaw = yaw_yawdot.first;
    cmd.yaw_dot = yaw_yawdot.second;

    // Remember the yaw exactly as it was stored in the message.
    last_yaw_ = cmd.yaw;

    // betaflight mode
    cmd.mode = cmd.POSITION_MODE;
    pos_cmd_pub.publish(cmd);
}

int main(int argc, char **argv) {
    ros::init(argc, argv, "traj_server");
    // ros::NodeHandle node;
    ros::NodeHandle nh("~");

    ros::Subscriber bspline_sub =
        nh.subscribe("planning/bspline", 10, bsplineCallback);

    // ros::Subscriber state_sub = nh.subscribe("/mpc/flag", 5, stateCallback);

    pos_cmd_pub = nh.advertise<quadrotor_msgs::Command>("/bfctrl/cmd", 50);

    ros::Timer cmd_timer = nh.createTimer(ros::Duration(0.01), cmdCallback);

    nh.param("traj_server/time_forward", time_forward_, -1.0);
    last_yaw_ = 0.0;
    last_yaw_dot_ = 0.0;

    ros::Duration(1.0).sleep();

    ROS_WARN("[Traj server]: ready.");

    ros::spin();

    return 0;
}